{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Importing and hosting an ONNX model with MXNet\n",
    "\n",
    "The [Open Neural Network Exchange](https://onnx.ai/) (ONNX) is an open format for representing deep learning models with its extensible computation graph model and definitions of built-in operators and standard data types.\n",
    "\n",
    "In this example, we will use the Super Resolution model from [Image Super-Resolution Using Deep Convolutional Networks](https://ieeexplore.ieee.org/document/7115171), where Dong et al. trained a model for taking a low-resolution image as input and producing a high-resolution one. This model, along with many others, can be found at the [ONNX Model Zoo](https://github.com/onnx/models).\n",
    "\n",
    "We will use the SageMaker Python SDK to host this ONNX model in SageMaker, and perform inference requests."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "First, we'll get the IAM execution role from our notebook environment, so that SageMaker can access resources in your AWS account later in the example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "isConfigCell": true
   },
   "outputs": [],
   "source": [
    "from sagemaker import get_execution_role\n",
    "\n",
    "role = get_execution_role()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The hosting script\n",
    "\n",
    "We'll need to provide a hosting script that can run on the SageMaker platform. This script will be invoked by SageMaker when we perform inference.\n",
    "\n",
    "The script we're using here implements two functions:\n",
    "\n",
    "* `model_fn()` - the SageMaker model server uses this function to load the model\n",
    "* `transform_fn()` - this function is for using the model to take the input and produce the output\n",
    "\n",
    "The script here is an adaptation of the [ONNX Super Resolution example](https://github.com/apache/incubator-mxnet/blob/master/example/onnx/super_resolution.py) provided by the [Apache MXNet](https://mxnet.incubator.apache.org/) project."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pygmentize super_resolution.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparing the model\n",
    "\n",
    "To create a SageMaker Endpoint, we'll first need to prepare the model to be used in SageMaker."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Downloading the model\n",
    "\n",
    "For this example, we will use a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models), where you can find a collection of pre-trained models to work with. Here, we will download the [Super Resolution](https://github.com/onnx/models#super-resolution) model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://onnx-mxnet.s3.amazonaws.com/examples/super_resolution.onnx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compressing the model data\n",
    "\n",
    "Now that we have the model data locally, we will need to compress it and upload the tarball to S3 for the SageMaker Python SDK to create a Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tarfile\n",
    "\n",
    "from sagemaker.session import Session\n",
    "\n",
    "with tarfile.open('onnx_model.tar.gz', mode='w:gz') as archive:\n",
    "    archive.add('super_resolution.onnx')\n",
    "\n",
    "model_data = Session().upload_data(path='onnx_model.tar.gz', key_prefix='model')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating a SageMaker Python SDK Model instance\n",
    "\n",
    "With the model data uploaded to S3, we now have everything we need to instantiate a SageMaker Python SDK Model. We'll provide the constructor the following arguments:\n",
    "\n",
    "* `model_data`: the S3 location of the model data\n",
    "* `entry_point`: the script for model hosting that we looked at above\n",
    "* `role`: the IAM role used\n",
    "* `framework_version`: the MXNet version in use\n",
    "\n",
    "You can read more about creating an `MXNetModel` object in the [SageMaker Python SDK API docs](https://sagemaker.readthedocs.io/en/latest/sagemaker.mxnet.html#mxnet-model)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.mxnet import MXNetModel\n",
    "\n",
    "mxnet_model = MXNetModel(model_data=model_data,\n",
    "                         entry_point='super_resolution.py',\n",
    "                         role=role,\n",
    "                         py_version='py3',\n",
    "                         framework_version='1.3.0')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating an Endpoint\n",
    "\n",
    "Now we can use our `MXNetModel` object to build and deploy an `MXNetPredictor`. This creates a SageMaker Model and Endpoint, the latter of which we can use for performing inference. \n",
    "\n",
    "The arguments to the `deploy()` function allow us to set the number and type of instances that will be used for the Endpoint. Here we will deploy the model to a single `ml.m4.xlarge` instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "predictor = mxnet_model.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performing inference\n",
    "\n",
    "With our Endpoint deployed, we can now send inference requests to it. We'll use one image as an example here."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preparing the image\n",
    "\n",
    "First, we'll download the image (and view it)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image as Img\n",
    "from mxnet.test_utils import download\n",
    "\n",
    "img_name = 'super_res_input.jpg'\n",
    "img_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/{}'.format(img_name)\n",
    "download(img_url, img_name)\n",
    "\n",
    "Img(filename=img_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll resize it to be 224x224 pixels. In addition, we'll use a grayscale version of the image (or, more accurately, taking the 'Y' channel after converting it to [YCbCr](https://en.wikipedia.org/wiki/YCbCr)) to match the images that were used for training the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "input_image_dim = 224\n",
    "img = Image.open(img_name).resize((input_image_dim, input_image_dim))\n",
    "\n",
    "img_ycbcr = img.convert('YCbCr')\n",
    "img_y, img_cb, img_cr = img_ycbcr.split()\n",
    "input_image = np.array(img_y)[np.newaxis, np.newaxis, :, :]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sending the inference request\n",
    "\n",
    "We'll now call `predict()` on our predictor to use our model to create a bigger image from the input image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = predictor.predict(input_image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Viewing the result\n",
    "\n",
    "Now we'll look at the resulting image from our inference request. First we'll convert it and save it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "img_out_y = Image.fromarray(np.uint8(np.asarray(out)), mode='L')\n",
    "result_img = Image.merge('YCbCr', [img_out_y,\n",
    "                         img_cb.resize(img_out_y.size, Image.BICUBIC),\n",
    "                         img_cr.resize(img_out_y.size, Image.BICUBIC)]).convert(\"RGB\")\n",
    "output_img_dim = 672\n",
    "assert result_img.size == (output_img_dim, output_img_dim)\n",
    "\n",
    "result_img_file = 'output.jpg'\n",
    "result_img.save(result_img_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now we'll look at the image itself. We can see that it is indeed a larger version of the image we started with."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Img(filename=result_img_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, we can look at the original image simply resized, without using the model. The lack of detail in this version is especially noticeable with the dog's fur."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "naive_output = Image.open(img_name).resize((output_img_dim, output_img_dim))\n",
    "\n",
    "naive_output_file = 'naive_output.jpg'\n",
    "naive_output.save(naive_output_file)\n",
    "\n",
    "Img(naive_output_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Deleting the Endpoint\n",
    "\n",
    "Since we've reached the end, we'll delete the SageMaker Endpoint to release the instance associated with it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.delete_endpoint()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_mxnet_p36",
   "language": "python",
   "name": "conda_mxnet_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  },
  "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.  Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
